SDPA decode perf improvements for qwen-3.5-35B-A3B#18759

Open
digantdesai wants to merge 6 commits into main from digantdesai/sdpa-bench-and-perf-stats

Conversation

@digantdesai
Contributor

@digantdesai digantdesai commented Apr 8, 2026

Performance Improvements for SDPA

Improves SDPA performance for decode sequences where $L_q = 1$.

Benchmark: qwen3.5-35B-A3B

  • Metric: Decode tokens/sec
  • Config: generate=1024 tokens, median of 3 runs on A100.
Prompt tokens   Baseline   Split-K   Original   Split-K speedup
      1           62.2       89.4      68.0         1.44x
     16           61.5       89.1      67.8         1.45x
     32           61.8       89.5      67.5         1.45x
     64           61.4       89.7      68.1         1.46x
    128           61.0       89.2        -          1.46x
    256           61.4       89.0        -          1.45x
    512           61.4       88.8        -          1.45x

(At the SDPA op level: ~10.2K calls (1024 tokens x 10 layers) went from 5.3 s to 209 ms, a ~25x speedup.)

Implementation Details

  • Max Context Length: 4K
  • Kernel Constraints:
    • Original: Capped at 64 tokens for prefill due to kernel tracing limitations from smaller example inputs.
    • Baseline: Updates example input shapes to remove the 64-token cap.
    • Prefill: Baseline and Split-K are equivalent for prefill (both use _sdpa_fwd_kernel_m64).

@pytorch-bot

pytorch-bot bot commented Apr 8, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18759

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 4 New Failures, 3 Unrelated Failures

As of commit 62428be with merge base 930ecfd:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 8, 2026
@github-actions

github-actions bot commented Apr 8, 2026

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@digantdesai digantdesai force-pushed the digantdesai/sdpa-bench-and-perf-stats branch from 2cb04c3 to febc419 Compare April 8, 2026 04:12
@digantdesai digantdesai changed the title [aoti-cuda] Add SDPA benchmarking script with qwen-3.5-35B-A3B shapes SDPA decode perf improvements for qwen-3.5-35B-A3B Apr 9, 2026
@digantdesai digantdesai marked this pull request as ready for review April 9, 2026 17:44
@digantdesai digantdesai requested a review from lucylq as a code owner April 9, 2026 17:44
Copilot AI review requested due to automatic review settings April 9, 2026 17:44
Contributor

Copilot AI left a comment


Pull request overview

This PR improves ExecuTorch CUDA SDPA decode performance for the common decode case where Lq = 1 (e.g., Qwen3.5 MoE generation), by introducing a Split-K “flash-decoding” Triton path and dispatching to it at runtime.

Changes:

  • Add a Split-K decode SDPA Triton kernel (sdpa_decode_splitk) plus a reduction kernel to improve occupancy when L_q == 1.
  • Update the Qwen3.5 MoE attention path to dispatch between Split-K (decode) and tiled SDPA (prefill) via torch.cond.
  • Add correctness tests and a benchmark script for SDPA decode shapes; update export example shapes to avoid overly-small AOTI shape specialization.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
examples/models/qwen3_5_moe/model.py Switch attention to Triton SDPA and add decode-time Split-K dispatch via torch.cond.
examples/models/qwen3_5_moe/main.cpp Plumb a stats callback into generation and print throughput/timing breakdown.
examples/models/qwen3_5_moe/export.py Use a max-length example sequence to prevent AOTI from baking in too-small intermediate buffers.
backends/cuda/triton/kernels/sdpa.py Implement Split-K decode kernel + reduction and expose sdpa_decode_splitk.
backends/cuda/triton/kernels/init.py Export sdpa_decode_splitk from the kernels package.
backends/cuda/tests/test_triton_sdpa_splitk.py Add CUDA BF16 unit tests validating Split-K correctness vs PyTorch SDPA reference.
backends/cuda/benchmarks/benchmark_sdpa.py Add a benchmark script comparing Triton SDPA/Split-K vs PyTorch SDPA backends.


Comment on lines +1369 to +1390
@triton_op("triton::sdpa_decode_splitk", mutates_args={})
def sdpa_decode_splitk(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float = 0.0,
    enable_gqa: bool = False,
) -> torch.Tensor:
    """Split-K flash-decoding SDPA for L_q=1 (decode step).

    Signature mirrors sdpa() for drop-in use with torch.cond dispatch.
    enable_gqa is accepted but ignored — GQA is handled natively via
    H_q // H_kv grouping; no packed-GQA tradeoff exists at L_q=1.
    """
    B, H_q, L_q, D = query.shape
    _, H_kv, L_kv, _ = key.shape

    out = torch.empty((B, H_q, L_q, D), device=query.device, dtype=query.dtype)


Copilot AI Apr 9, 2026


sdpa_decode_splitk() launches kernels that assume CUDA + bfloat16 inputs (and the reduce kernel stores bfloat16 unconditionally), but unlike sdpa() it never calls _validate_sdpa_inputs(). This can lead to hard-to-debug failures or incorrect behavior if the op is called with non-bfloat16 or non-CUDA tensors. Consider reusing _validate_sdpa_inputs(query, key, value, dropout_p, enable_gqa) at the top of sdpa_decode_splitk (and keep the existing feature/shape checks).

Comment on lines +1339 to +1343
    HAS_MASK=HAS_MASK,
    HEAD_DIM=D,
    NUM_GROUPS=num_groups,
    BLOCK_G=_next_power_of_2(num_groups),
)

Copilot AI Apr 9, 2026


BLOCK_G is derived via _next_power_of_2(num_groups), but _next_power_of_2() is documented as clamping to [16, 256] for head-dim tiling. Using it for num_groups forces a minimum of 16 (wasted work for MHA/low-group configs) and would be incorrect if num_groups > 256 (since BLOCK_G would clamp to 256 and not cover all groups). Consider a separate next_pow2 helper for group counts (no min clamp) and/or add a clear runtime check that num_groups <= 256 if that’s an intended constraint.
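Such a helper could look like the following (a minimal sketch; the name `next_pow2_unclamped` is hypothetical, and the `[16, 256]` clamp of the existing `_next_power_of_2` is as described in the comment above):

```python
def next_pow2_unclamped(n: int) -> int:
    """Smallest power of two >= n, with no [16, 256] clamp.

    Suitable for sizing BLOCK_G from a GQA group count, where
    num_groups may legitimately be 1 (plain MHA) or 8 (Qwen3.5 MoE),
    values the head-dim helper would inflate to 16.
    """
    if n < 1:
        raise ValueError(f"expected a positive group count, got {n}")
    return 1 << (n - 1).bit_length()
```

With this, MHA (`num_groups=1`) gets `BLOCK_G=1` and the Qwen3.5 config (`num_groups=8`) gets `BLOCK_G=8`, instead of both being padded to 16.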

device_name = torch.cuda.get_device_name()
print()
print("=" * 100)
print(f"SDPA Benchmark Qwen3.5-35B-A4B — {scenario}")

Copilot AI Apr 9, 2026


The benchmark header prints "Qwen3.5-35B-A4B", but this benchmark (and the PR description) refers to the A3B variant. This looks like a typo and can confuse readers when comparing numbers; consider correcting the printed model name.

Suggested change
print(f"SDPA Benchmark Qwen3.5-35B-A4B — {scenario}")
print(f"SDPA Benchmark Qwen3.5-35B-A3B — {scenario}")

@digantdesai
Contributor Author

@claude review this PR.

@claude

claude bot commented Apr 9, 2026

Claude finished @digantdesai's task in 2m 21s. View job


PR Review: SDPA decode perf improvements for qwen-3.5-35B-A3B

  • Read all changed files
  • Review Split-K decode kernel implementation (sdpa.py)
  • Review tests (test_triton_sdpa_splitk.py)
  • Review benchmark script (benchmark_sdpa.py)
  • Review model/export changes (model.py, export.py, main.cpp)
  • Post review feedback

Overall this is a well-structured PR. The Split-K flash-decoding approach is the right strategy for L_q=1 with low H_kv — the occupancy problem is real and the 1.45x end-to-end speedup is meaningful. The kernel implementation, torch.cond dispatch, and test coverage are solid. A few issues below, roughly ordered by severity.


Issues

1. Missing input validation in sdpa_decode_splitk (sdpa.py:1386-1430)

sdpa() calls _validate_sdpa_inputs() (checking CUDA device, bf16 dtype, 4D shape, dropout=0) before any kernel work. sdpa_decode_splitk() does its own partial validation (dropout, is_causal, L_q, H_q divisibility, pow2 D) but never checks that tensors are CUDA or bf16. Passing float32 or CPU tensors would reach the Triton kernel and produce silent wrong results or cryptic errors.

Call _validate_sdpa_inputs(query, key, value, dropout_p, enable_gqa=True) at the top, then keep the Split-K-specific checks (L_q==1, pow2 D) afterward. The existing dropout_p != 0 check would then become redundant and can be removed.


2. _next_power_of_2 used for BLOCK_G has wrong domain (sdpa.py:1342)

_next_power_of_2() clamps to [16, 256], designed for HEAD_DIM tiling. When used for BLOCK_G = _next_power_of_2(num_groups), MHA configs (num_groups=1) get BLOCK_G=16, which means the split-K kernel allocates 16-wide vectors/accumulators but only uses 1 element, wasting 15/16 of register resources and compute. For the Qwen3.5 config (num_groups=8), BLOCK_G=16 — still 50% waste.

This doesn't cause incorrect results (the g_valid mask handles it), but it's a missed optimization opportunity. A separate helper without the min=16 clamp would give BLOCK_G=1 for MHA, BLOCK_G=8 for Qwen3.5, etc.


3. Reduce kernel output addressing assumes B=1 (sdpa.py:1259-1262)

The reduce kernel uses pid * stride_oh to compute the output pointer, where grid_reduce = (B * H_q,). For B>1, the output is [B, H_q, 1, D] and stride_oh = 1 * D (for contiguous layout). The correct offset for pid = b * H_q + h_q is b * stride_ob + h_q * stride_oh + 0 * stride_om. Since stride_ob = H_q * 1 * D = H_q * stride_oh, writing pid * stride_oh is equivalent to (b * H_q + h_q) * stride_oh = b * stride_ob + h_q * stride_oh, which is correct only when stride_ob == H_q * stride_oh (i.e., the tensor is contiguous in the B dimension). This holds for freshly-allocated contiguous tensors, but it's fragile and the comment at line 1259 should note this assumption more explicitly.

The same pattern applies to the partial buffers — pid * stride_op_h assumes contiguous layout of O_partial[splits, B, H_q, D]. Since these are allocated in _launch_decode_splitk and guaranteed contiguous, this is safe, but worth a comment.
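The contiguity argument above can be checked with plain index arithmetic (a standalone sketch with hypothetical helper names, not the kernel's actual code):

```python
def reduce_out_offset(pid: int, H_q: int, D: int) -> int:
    """Offset as the reduce kernel computes it: pid * stride_oh."""
    stride_oh = 1 * D  # contiguous [B, H_q, 1, D] output
    return pid * stride_oh


def explicit_offset(pid: int, H_q: int, D: int) -> int:
    """Offset via explicit decomposition: b*stride_ob + h*stride_oh."""
    stride_oh = 1 * D
    stride_ob = H_q * 1 * D  # holds only for a B-contiguous layout
    b, h = divmod(pid, H_q)  # grid is (B * H_q,)
    return b * stride_ob + h * stride_oh
```

For a contiguous output the two agree for every program id, because `stride_ob == H_q * stride_oh` makes `pid * stride_oh` telescope into the explicit form; a non-contiguous `stride_ob` would break the shortcut.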

4. Benchmark typo: "A4B" should be "A3B" (benchmark_sdpa.py:181)

print(f"SDPA Benchmark Qwen3.5-35B-A4B — {scenario}")

Should be "A3B" to match the model variant. Copilot flagged this too.



Suggestions (non-blocking)

5. Partial accumulator stores unscaled acc rather than acc / l_i (sdpa.py:1186-1204)

The split-K kernel stores the raw unnormalized accumulator acc and separate m_i / l_i. The reduce kernel then rescales and normalizes. This is the standard flash-decoding approach and is correct. However, storing acc without the l_i normalization means partial outputs can have very large magnitudes (especially with long chunks), which might lose precision in float32 for extreme cases. Not a problem for the 4K max context here, but worth keeping in mind if this is extended to longer contexts.
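For reference, the rescale-and-normalize step the reduce kernel performs can be illustrated with scalar values (a self-contained sketch of the standard flash-decoding reduction, not the kernel code; real kernels carry a D-wide acc vector in fp32):

```python
import math


def merge_partials(partials):
    """Merge per-split (m, l, acc) flash-decoding partials.

    Each partial holds the split's running max m, the softmax
    denominator l = sum(exp(s - m)), and the unnormalized
    accumulator acc = sum(exp(s - m) * v).
    """
    m_global = max(m for m, _, _ in partials)
    l_total, acc_total = 0.0, 0.0
    for m, l, acc in partials:
        scale = math.exp(m - m_global)  # rebase onto the global max
        l_total += l * scale
        acc_total += acc * scale
    return acc_total / l_total  # normalize once, at the very end
```

Because each split's `acc` is stored unnormalized, only one multiply per partial is needed in the reduce step; the trade-off, as noted above, is that partial magnitudes grow with chunk length.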

6. num_splits heuristic (sdpa.py:1283)

num_splits = min(max(triton.cdiv(L_kv, 256), 1), 128)

The choice of 256 tokens per split and max 128 splits is reasonable for A100 with ~108 SMs, but this could benefit from being SM-count-aware (e.g., torch.cuda.get_device_properties().multi_processor_count). For smaller GPUs (e.g., A10 with 72 SMs), 128 splits may over-partition. Non-blocking since the target is A100.
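An SM-count-aware variant might look like this (a sketch only; `tokens_per_split=256` and `max_splits=128` mirror the heuristic quoted above, while capping by SM count is the added suggestion):

```python
def pick_num_splits(
    L_kv: int,
    sm_count: int,
    tokens_per_split: int = 256,
    max_splits: int = 128,
) -> int:
    """Split count bounded by available work, a hard cap, and SM count."""
    by_work = max(-(-L_kv // tokens_per_split), 1)  # ceil(L_kv / 256)
    return max(1, min(by_work, max_splits, sm_count))
```

On CUDA, `sm_count` would come from `torch.cuda.get_device_properties(device).multi_processor_count`, as suggested above, so an A10 (72 SMs) would never be over-partitioned into 128 splits.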

7. Test tolerance (test_triton_sdpa_splitk.py)

All tests use max_abs_error < 0.05, which is reasonable for bf16 SDPA (expected ~1e-2 to 3e-2 range). The cross-validation test against the tiled SDPA kernel (test_cross_validate_with_sdpa) also uses 0.05 — this is good since both paths accumulate differently.

8. torch.cond dispatch in model (model.py:276-281)

The torch.cond dispatch is clean:

y = torch.cond(
    q.shape[2] == 1,
    lambda q, k, v, mask: sdpa_decode_splitk(q, k, v, attn_mask=mask),
    lambda q, k, v, mask: sdpa(q, k, v, attn_mask=mask, enable_gqa=True),
    [q, k, v, attn_mask],
)

Note that during AOTI tracing, both branches are traced with symbolic shapes, so sdpa_decode_splitk correctly guards its L_q==1 assertion with isinstance(L_q, int) (line 1404). This is a good pattern.

9. Export shape change (export.py:381-383)

Changing example_seq_len to config.max_seq_len - 1 is the right fix. The previous smaller example caused AOTI to specialize intermediate buffers for small sizes, capping effective prefill at 64 tokens. The comment explaining this is helpful.


Summary

The kernel implementation is correct and well-tested. The main actionable item is issue #1 (missing CUDA/bf16 validation in sdpa_decode_splitk), which is a real gap that could cause confusing failures. Issue #2 (BLOCK_G over-allocation) is a perf optimization opportunity. Issue #4 is a trivial typo fix.


Compares ET Triton SDPA (native GQA) against PyTorch Flash/Efficient/Math
backends (expanded KV) across Lk=64..16K on A100. Uses triton.testing.do_bench
for timing. Standalone script, no changes to the kernel.
Add a stats_callback to generate() that prints prefill/decode rates,
model load time, TTFT, and sampling time via printf, mirroring the
format in extension/llm/runner/stats.h print_report.

Uses printf instead of ET_LOG(Info) because the CMake target does not
link executorch_no_prim_ops (which provides the PAL logger); adding
that dependency pulls in the full runtime and breaks the minimal
runner build.
Register `triton::sdpa_decode_splitk` as an independent op so AOTI
can trace and compile it without the runtime L_kv conditional that
prevents the split-K path from appearing in the standard `sdpa` op.

The split-K (flash-decoding) approach partitions the KV sequence
across CTAs and reduces partial softmax results in a second kernel.
The benchmark script now includes the split-K column for comparison.

BLOCK_G (the GQA group tile) uses _next_power_of_2_unclamped() to
avoid inflating small group counts to 16. Phantom rows from
over-sized tiles change register pressure and instruction scheduling,
altering fp32 accumulation order enough to degrade output quality
over long autoregressive sequences.

Standalone kernel benchmark on H100 (Qwen3.5 MoE decode, B=1, H_q=16,
H_kv=2, D=256, bf16):

  Lk       ET Tiled (us)  ET Split-K (us)  Speedup
  64            131.8          259.5         0.5x
  512            98.9          221.5         0.4x
  4096          199.9          214.4         0.9x
  8192          392.2          211.3         1.9x
  16384         775.3          211.8         3.7x

Split-K breaks even around Lk=4096 and dominates at longer sequences
where the tiled kernel's single-CTA-per-head bottleneck becomes severe.
The previous example used T=2, which caused AOTI to compile the
chunk_gated_delta_rule kernel for a single chunk (NT=1). At runtime,
prompts longer than 64 tokens (requiring NT>1 chunks) failed with
"Error resizing tensor at input 0". Using max_seq_len-1 as the
example ensures AOTI generalizes intermediate buffer sizes for the
full sequence length range.

Comparison against original export (tq4_sdpa fused kernel)
on H100 (Qwen3.5-35B-A3B, HQQ-INT4, max_seq_len=4096, 5 runs median):

                Original (tq4_sdpa)  Baseline (Triton SDPA)
  Decode tok/s       68.4               61.7
  Prefill tok/s     275.7              378.2

Baseline prefill is 1.37x faster; decode is 0.90x (tq4_sdpa's fused
decode kernel is faster than the tiled Triton SDPA at L_q=1). The
split-K commit addresses the decode gap.
Runtime dispatch via torch.cond in FullAttention: split-K flash-decoding
for decode (L_q==1) and standard tiled SDPA for prefill (L_q>1). Guard
sdpa_decode_splitk validation behind isinstance(L_q, int) so AOTI tracing
with symbolic shapes doesn't trip the L_q==1 check.
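The guard pattern reduces to a shape-only sketch (hypothetical standalone helper; in the PR the check lives inside sdpa_decode_splitk itself):

```python
def assert_decode_shape(shape) -> None:
    """Assert L_q == 1 only when L_q is a concrete Python int.

    Under AOTI/torch.export tracing, sizes may be symbolic (SymInt),
    so the assertion is skipped for non-int sizes rather than
    tripping while the non-decode branch of torch.cond is traced.
    """
    B, H_q, L_q, D = shape
    if isinstance(L_q, int):
        assert L_q == 1, f"split-K decode path requires L_q == 1, got {L_q}"
```

At runtime the shapes are concrete ints, so a real misuse (L_q > 1 reaching the split-K path) still fails loudly.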

Align sdpa_decode_splitk signature with sdpa (dropout_p, is_causal,
enable_gqa) for drop-in use with torch.cond; unsupported args fail
with clear messages.

End-to-end on H100 (Qwen3.5-35B-A3B, HQQ-INT4, max_seq_len=4096,
1024 decode tokens, prompt="Hi", temperature=0, 5 runs median):

                Baseline (tiled)    Split-K     Speedup
  Decode tok/s         61.7          89.9        1.46x
  Prefill tok/s       378.2         378.2        1.00x

  nsys GPU time     13853 ms        8674 ms      1.60x
  SDPA kernel      5370 ms (38.8%)  209 ms (2.4%) 25.7x
@digantdesai digantdesai force-pushed the digantdesai/sdpa-bench-and-perf-stats branch from ebe61e8 to 5d3b620 Compare April 10, 2026 04:37
Import ordering, line-length wrapping, and missing blank lines
flagged by CI lintrunner.
Copilot AI review requested due to automatic review settings April 10, 2026 04:59
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.



Comment on lines +90 to +93
printf(
    "\n\tPrompt Tokens: %" PRIu64 " Generated Tokens: %" PRIu64,
    stats.num_prompt_tokens,
    stats.num_generated_tokens);
Comment on lines +240 to +248
for name, label, _ in backends:
    if name == ref_name or outputs[name] is None:
        continue
    err = _max_abs_error(outputs[name], ref_out)
    assert err < 1e-2, (
        f"Output mismatch for {_shape_label(shape)}: "
        f"{label} vs {BACKENDS[ref_name][0]}, "
        f"max abs error {err:.3e} >= 1e-2"
    )
out = self.splitk(q, k, v, attn_mask=mask)

self.assertFalse(torch.isnan(out).any(), "All-masked should not NaN")
self.assertFalse(torch.isinf(out).any(), "All-masked should not Inf")